import os
import numpy as np
import torchvision
import torch
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import random_split
import warnings


def get_dataloaders_cifar(settings):
    """Build the train, validation and test DataLoaders for CIFAR-10/100.

    The official training split is divided at random into a training part
    and a validation part; the official test split is used unchanged.

    Args:
        settings: Configuration object. The following attributes are read:
            ``dataset`` (``"cifar10"`` or ``"cifar100"``),
            ``datasets_path`` (directory that contains one sub-directory per
            dataset, named after ``dataset``),
            ``batch_size`` (batch size used by all three loaders), and
            ``val_set_perc`` (fraction of the training split that is held
            out for validation).

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.

    Raises:
        ValueError: If ``settings.dataset`` is not a supported dataset name.
    """
    # Validate the dataset name before doing any work, so that an unsupported
    # value fails with a clear message instead of an unbound-variable error
    # further down.
    dataset_class_names = {"cifar10": "CIFAR10", "cifar100": "CIFAR100"}
    if settings.dataset not in dataset_class_names:
        raise ValueError(
            f"Unsupported dataset {settings.dataset!r}; "
            f"expected one of {sorted(dataset_class_names)}"
        )
    dataset_cls = getattr(torchvision.datasets, dataset_class_names[settings.dataset])

    # The same normalisation constants are applied to both datasets and to
    # every split.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )
    # Augmentation (random crop + horizontal flip) is applied to training
    # samples only.
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    dataset_location = os.path.join(settings.datasets_path, settings.dataset)

    # The training split is instantiated twice: both objects read the same
    # data, but validation samples must go through the non-augmenting
    # transform. The datasets are expected to be on disk already
    # (download=False).
    train_set = dataset_cls(
        root=dataset_location,
        train=True,
        download=False,
        transform=transform_train,
    )
    val_set = dataset_cls(
        root=dataset_location,
        train=True,
        download=False,
        transform=transform_test,
    )
    test_set = dataset_cls(
        root=dataset_location,
        train=False,
        download=False,
        transform=transform_test,
    )

    # Create the train/validation split: the first ``val_set_perc`` fraction
    # of the shuffled indices becomes the validation set, the rest is used
    # for training. The shuffle draws from NumPy's global random state.
    num_train = len(train_set)
    indices = list(range(num_train))
    np.random.shuffle(indices)
    split = int(np.floor(settings.val_set_perc * num_train))
    train_idx, val_idx = indices[split:], indices[:split]

    # Options shared by all three loaders.
    loader_kwargs = dict(
        batch_size=settings.batch_size,
        num_workers=4,
        pin_memory=True,
    )

    train_loader = DataLoader(
        train_set,
        sampler=SubsetRandomSampler(train_idx),
        **loader_kwargs,
    )
    val_loader = DataLoader(
        val_set,
        sampler=SubsetRandomSampler(val_idx),
        **loader_kwargs,
    )
    test_loader = DataLoader(
        test_set,
        shuffle=False,
        **loader_kwargs,
    )

    return train_loader, val_loader, test_loader
